//Rainbow.scala
package mycore

import chisel3._
import chisel3.util._

/** Top-level core: a five-stage IF -> ID -> EX -> MEM -> WB pipeline.
  *
  * Wires the [[Fetch]], [[InstDecode]], [[Execute]], [[Memory]] and [[WriteBack]] stages
  * together, exposes the instruction-fetch and data load/store ports, and forwards the
  * difftest records produced by [[WriteBack]]. It also implements the "compaction" stall
  * (see the comment further down) that is triggered by `Execute.io.cancel`.
  *
  * Every stage ENABLE and every register updated in this module is gated by `io.ENABLE`.
  */
class Rainbow extends Module{
  val io = IO(new Bundle{

    val ENABLE = Input(Bool())
    val InstFetch = new InstFetch
    val DataLoad = new DataLoad
    val DataStore = new DataStore

    // difftest
    val InstrCommit = Flipped(new DiffInstrCommitIO)
    val ArchIntRegState = Flipped(new DiffArchIntRegStateIO)
    val CSRState = Flipped(new DiffCSRStateIO)
    val TrapEvent = Flipped(new DiffTrapEventIO)
    val ArchFpRegState = Flipped(new DiffArchFpRegStateIO)
    val ArchEvent = Flipped(new DiffArchEventIO)

    val isSoC = Input(Bool())
  })

  val Fetch       = Module(new Fetch)
  val InstDecode  = Module(new InstDecode)
  val Execute     = Module(new Execute)
  val Memory      = Module(new Memory)
  val WriteBack   = Module(new WriteBack)

  Fetch.io.isSoC := io.isSoC
  WriteBack.io.isSoC := io.isSoC

  // Linear data flow between consecutive stages.
  Fetch.io.ifid       <> InstDecode.io.ifid
  InstDecode.io.idex  <> Execute.io.idex
  Execute.io.exmem    <> Memory.io.exmem
  Memory.io.memwb     <> WriteBack.io.memwb

  // Results fed back from MEM and WB into EX (presumably for operand forwarding).
  Execute.io.redesMem <> Memory.io.redesMem
  Execute.io.redesWb  <> WriteBack.io.wbdata

  Execute.io.TimeOver <> Memory.io.TimeOver

  // Fetch.io.cancel, InstDecode.io.cancel, Fetch.io.nextPC and InstDecode.io.wbdata are
  // driven ONLY by the compaction logic below. Do not also bulk-connect them here: with
  // Chisel's last-connect semantics such a `<>` would be silently overridden by the `:=`.

  /* Problem: when `mret` has just updated mstatus and a timer interrupt fires right
   * after it, mepc would record the PC of a wrongly fetched instruction.
   * Fix: "compaction". When EX raises `cancel`, IF and ID receive `cancel` in the first
   * cycle, and EX/MEM/WB are stalled for two cycles.
   *
   *        IF    ID    EX    MEM    WB
   *  t:   ----|-----|-----|------|-----
   *        4c   mret   44    40     38
   *
   *  t+1: ----|-----|-----|------|-----
   *        48    4c   mret   44     40
   *
   *  t+2: ----|_____|-----|------|-----
   *        a0    48   mret   44     40
   *
   *  t+3: ----|-----|-----|------|-----
   *        a4    a0   mret   44     40
   *
   *  t+1 -> first cycle of Execute.io.cancel
   *  t+2 -> second cycle of Execute.io.cancel
   */

  val Campact = Execute.io.cancel
  // Compaction phase counter (values 1..3). While `cancel` is high it steps 1 -> 2 -> 3.
  // From 3 it always returns to 1, regardless of `cancel`. Only advances when enabled.
  val CampactCnt = RegInit(1.U(4.W))

  when(io.ENABLE){
    when(CampactCnt === 3.U) {
      CampactCnt := 1.U
    }.elsewhen(Campact){
      CampactCnt := CampactCnt + 1.U
    }
  }

  val CampactCycleOne = Campact && (CampactCnt === 1.U)
  val CampactCycleTwo = Campact && (CampactCnt === 2.U)
  val CampactCycleThree = Campact && (CampactCnt === 3.U)
  // EX/MEM/WB are frozen, and WB results are hidden from ID, during the first two cycles.
  val CampactStall = CampactCycleOne || CampactCycleTwo

  // IF and ID see `cancel` only in the first compaction cycle.
  Fetch.io.cancel       := CampactCycleOne
  InstDecode.io.cancel  := CampactCycleOne

  Fetch.io.ENABLE       := io.ENABLE
  InstDecode.io.ENABLE  := io.ENABLE
  Execute.io.ENABLE     := io.ENABLE && !CampactStall
  Memory.io.ENABLE      := io.ENABLE && !CampactStall
  WriteBack.io.ENABLE   := io.ENABLE && !CampactStall

  InstDecode.io.wbdata := Mux(CampactStall, 0.U.asTypeOf(new WBDATA), WriteBack.io.wbdata)
  // Execute.io.nextPC reaches IF only in the first compaction cycle; otherwise IF sees zeros.
  Fetch.io.nextPC := Mux(CampactCycleOne, Execute.io.nextPC, 0.U.asTypeOf(new NEXTPC))

  // IO

  // Load data captured in the first compaction cycle. It is replayed to MEM while
  // LoadStoreBlocker is active, and cleared in the third compaction cycle.
  val LoadDataInCampact = RegInit(0.U(64.W))

  when(io.ENABLE){
    when(CampactCycleOne){
      LoadDataInCampact := io.DataLoad.data
    }.elsewhen(CampactCycleThree){
      LoadDataInCampact := 0.U
    }
  }

  io.InstFetch  <> Fetch.io.InstFetch

  // Even a load can change machine state, so both loads and stores are blocked.
  // The blocker is a registered copy of CycleOne/CycleTwo, so it is active in the cycle
  // after each of the first two compaction cycles.
  val LoadStoreBlocker = RegEnable(CampactCycleOne, false.B, io.ENABLE) || RegEnable(CampactCycleTwo, false.B, io.ENABLE)
  io.DataLoad.en := Memory.io.DataLoad.en && !LoadStoreBlocker
  io.DataLoad.addr := Memory.io.DataLoad.addr
  Memory.io.DataLoad.data := Mux(LoadStoreBlocker, LoadDataInCampact, io.DataLoad.data)
  io.DataLoad.size := Memory.io.DataLoad.size

  io.DataStore.en := Memory.io.DataStore.en && !LoadStoreBlocker
  io.DataStore.addr := Memory.io.DataStore.addr
  io.DataStore.data := Memory.io.DataStore.data
  io.DataStore.mask := Memory.io.DataStore.mask

  // difftest
  InstDecode.io.gpr   <> WriteBack.io.gpr
  io.InstrCommit      <> WriteBack.io.InstrCommit
  io.ArchIntRegState  <> WriteBack.io.ArchIntRegState
  io.CSRState         <> WriteBack.io.CSRState
  io.TrapEvent        <> WriteBack.io.TrapEvent
  io.ArchFpRegState   <> WriteBack.io.ArchFpRegState
  io.ArchEvent        <> WriteBack.io.ArchEvent

}



/** Marker mixin shared by all difftest bundles; it currently declares no members. */
trait DifftestParameter

/** Mixin that adds a `clock` input port to a difftest bundle. */
trait DifftestWithClock { val clock = Input(Clock()) }

/** Mixin that adds an 8-bit `coreid` input port to a difftest bundle. */
trait DifftestWithCoreid { val coreid = Input(UInt(8.W)) }

/** Mixin that adds an 8-bit `index` input port to a difftest bundle. */
trait DifftestWithIndex { val index = Input(UInt(8.W)) }

/** Base class of every difftest bundle: contributes `clock` and `coreid` ports (mixin order kept). */
abstract class DifftestBundle extends Bundle with DifftestParameter with DifftestWithClock with DifftestWithCoreid

/** Difftest record for architectural events; field names suggest interrupt/exception reporting.
  *
  * All fields are Inputs from the consumer's side; [[Rainbow]] exposes it as
  * `Flipped(new DiffArchEventIO)`, so the core drives every field.
  * NOTE(review): field names/order presumably must match the external difftest module.
  */
class DiffArchEventIO extends DifftestBundle {
  // presumably interrupt number and exception cause — TODO confirm encoding against difftest
  val intrNO = Input(UInt(32.W))
  val cause = Input(UInt(32.W))
  val exceptionPC = Input(UInt(64.W))
  val exceptionInst = Input(UInt(32.W))
}

/** Per-instruction commit record for the difftest harness.
  *
  * Besides the fields below it carries `clock`, `coreid` and an 8-bit `index`
  * (from [[DifftestWithIndex]]). [[Rainbow]] exposes it as `Flipped(new DiffInstrCommitIO)`.
  * NOTE(review): field names/order presumably must match the external difftest module.
  */
class DiffInstrCommitIO extends DifftestBundle with DifftestWithIndex {
  val valid    = Input(Bool())
  val pc       = Input(UInt(64.W))
  val instr    = Input(UInt(32.W))
  val skip     = Input(Bool())
  val isRVC    = Input(Bool())
  val scFailed = Input(Bool())
  // presumably the register write-back of the committed instruction: enable, value, dest index
  val wen      = Input(Bool())
  val wdata    = Input(UInt(64.W))
  val wdest    = Input(UInt(8.W))
}

/** Trap record for the difftest harness: a 3-bit `code`, the PC, and 64-bit cycle/instruction counters.
  *
  * NOTE(review): presumably used to report the end of simulation (e.g. good/bad trap); confirm
  * the meaning of `code` against the difftest C++ side.
  */
class DiffTrapEventIO extends DifftestBundle {
  val valid    = Input(Bool())
  val code     = Input(UInt(3.W))
  val pc       = Input(UInt(64.W))
  val cycleCnt = Input(UInt(64.W))
  val instrCnt = Input(UInt(64.W))
}

/** Snapshot of the privilege mode (2 bits) and machine/supervisor CSRs (64 bits each) for difftest.
  *
  * `priviledgeMode` is misspelled on purpose. NOTE(review): the name presumably has to match the
  * port of the external difftest module, so only rename it together with that side.
  */
class DiffCSRStateIO extends DifftestBundle {
  val priviledgeMode = Input(UInt(2.W))
  val mstatus = Input(UInt(64.W))
  val sstatus = Input(UInt(64.W))
  val mepc = Input(UInt(64.W))
  val sepc = Input(UInt(64.W))
  val mtval = Input(UInt(64.W))
  val stval = Input(UInt(64.W))
  val mtvec = Input(UInt(64.W))
  val stvec = Input(UInt(64.W))
  val mcause = Input(UInt(64.W))
  val scause = Input(UInt(64.W))
  val satp = Input(UInt(64.W))
  val mip = Input(UInt(64.W))
  val mie = Input(UInt(64.W))
  val mscratch = Input(UInt(64.W))
  val sscratch = Input(UInt(64.W))
  val mideleg = Input(UInt(64.W))
  val medeleg = Input(UInt(64.W))
}

/** Integer register file snapshot for difftest: 32 entries of 64 bits each. */
class DiffArchIntRegStateIO extends DifftestBundle { val gpr = Input(Vec(32, UInt(64.W))) }

/** Floating-point register file snapshot for difftest: 32 entries of 64 bits each. */
class DiffArchFpRegStateIO extends DifftestBundle { val fpr = Input(Vec(32, UInt(64.W))) }

/** Store-buffer event record for difftest: 64 data bytes plus a 64-bit mask.
  *
  * NOTE(review): the mask presumably has one bit per data byte. [[Rainbow]] does not connect it.
  */
class DiffSbufferEventIO extends DifftestBundle {
  val sbufferResp = Input(Bool())
  val sbufferAddr = Input(UInt(64.W))
  val sbufferData = Input(Vec(64, UInt(8.W)))
  val sbufferMask = Input(UInt(64.W))
}

/** Store event record for difftest: 64-bit address and data, 8-bit mask, plus an 8-bit `index`.
  *
  * [[Rainbow]] does not connect it.
  */
class DiffStoreEventIO extends DifftestBundle with DifftestWithIndex {
  val valid       = Input(Bool())
  val storeAddr   = Input(UInt(64.W))
  val storeData   = Input(UInt(64.W))
  val storeMask   = Input(UInt(8.W))
}

/** Load event record for difftest: physical address plus 8-bit `opType`/`fuType` codes and an 8-bit `index`.
  *
  * NOTE(review): the encoding of `opType`/`fuType` is defined elsewhere (presumably by difftest).
  * [[Rainbow]] does not connect it.
  */
class DiffLoadEventIO extends DifftestBundle with DifftestWithIndex {
  val valid  = Input(Bool())
  val paddr  = Input(UInt(64.W))
  val opType = Input(UInt(8.W))
  val fuType = Input(UInt(8.W))
}

/** Atomic memory operation record for difftest: address, operand data, mask, op code and result.
  *
  * [[Rainbow]] does not connect it.
  */
class DiffAtomicEventIO extends DifftestBundle {
  val atomicResp = Input(Bool())
  val atomicAddr = Input(UInt(64.W))
  val atomicData = Input(UInt(64.W))
  val atomicMask = Input(UInt(8.W))
  val atomicFuop = Input(UInt(8.W))
  val atomicOut  = Input(UInt(64.W))
}

/** Page-table-walk record for difftest: address plus four 64-bit data words.
  *
  * [[Rainbow]] does not connect it.
  */
class DiffPtwEventIO extends DifftestBundle {
  val ptwResp = Input(Bool())
  val ptwAddr = Input(UInt(64.W))
  val ptwData = Input(Vec(4, UInt(64.W)))
}

/** External `DifftestArchEvent` module (BlackBox) whose ports are [[DiffArchEventIO]]. */
class DifftestArchEvent extends BlackBox { val io = IO(new DiffArchEventIO) }

/** External `DifftestInstrCommit` module (BlackBox) whose ports are [[DiffInstrCommitIO]]. */
class DifftestInstrCommit extends BlackBox { val io = IO(new DiffInstrCommitIO) }

/** External `DifftestTrapEvent` module (BlackBox) whose ports are [[DiffTrapEventIO]]. */
class DifftestTrapEvent extends BlackBox { val io = IO(new DiffTrapEventIO) }

/** External `DifftestCSRState` module (BlackBox) whose ports are [[DiffCSRStateIO]]. */
class DifftestCSRState extends BlackBox { val io = IO(new DiffCSRStateIO) }

/** External `DifftestArchIntRegState` module (BlackBox) whose ports are [[DiffArchIntRegStateIO]]. */
class DifftestArchIntRegState extends BlackBox { val io = IO(new DiffArchIntRegStateIO) }

/** External `DifftestArchFpRegState` module (BlackBox) whose ports are [[DiffArchFpRegStateIO]]. */
class DifftestArchFpRegState extends BlackBox { val io = IO(new DiffArchFpRegStateIO) }

/** External `DifftestSbufferEvent` module (BlackBox) whose ports are [[DiffSbufferEventIO]]. */
class DifftestSbufferEvent extends BlackBox { val io = IO(new DiffSbufferEventIO) }

/** External `DifftestStoreEvent` module (BlackBox) whose ports are [[DiffStoreEventIO]]. */
class DifftestStoreEvent extends BlackBox { val io = IO(new DiffStoreEventIO) }

/** External `DifftestLoadEvent` module (BlackBox) whose ports are [[DiffLoadEventIO]]. */
class DifftestLoadEvent extends BlackBox { val io = IO(new DiffLoadEventIO) }

/** External `DifftestAtomicEvent` module (BlackBox) whose ports are [[DiffAtomicEventIO]]. */
class DifftestAtomicEvent extends BlackBox { val io = IO(new DiffAtomicEventIO) }

/** External `DifftestPtwEvent` module (BlackBox) whose ports are [[DiffPtwEventIO]]. */
class DifftestPtwEvent extends BlackBox { val io = IO(new DiffPtwEventIO) }

// Difftest emulator top

// XiangShan log / perf control; should be initialized in the SimTop IO.
// If not needed, just ignore these signals.
/** XiangShan performance-counter control inputs: `clean` and `dump` (one bit each).
  *
  * NOTE(review): presumably `clean` resets and `dump` prints the counters; confirm on the harness side.
  */
class PerfInfoIO extends Bundle {
  val clean = Input(Bool())
  val dump = Input(Bool())
}

/** XiangShan log control inputs, all 64 bits wide.
  *
  * NOTE(review): `log_begin`/`log_end` presumably bound the window in which logging is active,
  * and `log_level` selects verbosity; confirm against the simulation harness.
  */
class LogCtrlIO extends Bundle {
  val log_begin, log_end = Input(UInt(64.W))
  val log_level = Input(UInt(64.W)) // a cpp uint
}

// UART IO; if needed, it should be initialized in the SimTop IO.
// If not needed, just hardwire all outputs to 0.
/** UART port bundle in the XiangShan SimTop style.
  *
  * `out`: `valid` and `ch` are both Outputs (a character emitted by the design).
  * `in`: `valid` is an Output while `ch` is an Input. NOTE(review): presumably the design raises
  * `in.valid` to request a character and the harness returns it on `in.ch`; confirm.
  */
class UARTIO extends Bundle {
  val out = new Bundle {
    val valid = Output(Bool())
    val ch = Output(UInt(8.W))
  }
  val in = new Bundle {
    val valid = Output(Bool())
    val ch = Input(UInt(8.W))
  }
}

/** Package object for `mycore.difftest`; it is currently empty and defines nothing. */
package object difftest {
  
}